{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "import re\n",
    "import os\n",
    "import gensim\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import string\n",
    "from nltk.corpus import stopwords\n",
    "from nltk.tokenize import sent_tokenize, word_tokenize, RegexpTokenizer\n",
    "import json\n",
    "import time\n",
    "from collections import Counter\n",
    "from tqdm import tqdm"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "trainfilename = '../../data/train.tsv'\n",
    "validfilename = '../../data/valid.tsv'\n",
    "testfilename = '../../data/personalized_test.tsv'\n",
    "docsfilename = '../../data/news.tsv'\n",
    "stop_words = set(stopwords.words('english'))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "WORD_FREQ_THRESHOLD = 3\n",
    "MAX_CONTENT_LEN = 500\n",
    "MAX_BODY_LEN = 100\n",
    "MAX_TITLE_LEN = 16\n",
    "WORD_EMBEDDING_DIM = 300\n",
    "MAX_CLICK_LEN = 50\n",
    "\n",
    "word2freq = {}\n",
    "word2index = {}"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "def word_tokenize(sent):\n",
    "    pat = re.compile(r'[\\w]+|[.,!?;|]')\n",
    "    if isinstance(sent, str):\n",
    "        return pat.findall(sent.lower())\n",
    "    else:\n",
    "        return []"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "def read_news(filename,filer_num=3):\n",
    "    news={}\n",
    "    category, subcategory=[], []\n",
    "    news_index={}\n",
    "    index=1\n",
    "    word_cnt=Counter()\n",
    "    err = 0\n",
    "    news_data = pd.read_csv(filename, sep='\\t')\n",
    "    news_data.fillna(value=\" \", inplace=True)\n",
    "    for i in tqdm(range(len(news_data))):\n",
    "        doc_id,vert,_, title, snipplet= news_data.loc[i,:][:5]\n",
    "        news_index[doc_id]=index\n",
    "        index+=1\n",
    "\n",
    "        title = title.lower()\n",
    "        title = word_tokenize(title)\n",
    "        snipplet = snipplet.lower()\n",
    "        snipplet = word_tokenize(snipplet)\n",
    "        category.append(vert)\n",
    "        news[doc_id] = [vert,title,snipplet]     \n",
    "        word_cnt.update(snipplet+title)\n",
    "    # 0: pad; 1: <sos>; 2: <eos>\n",
    "    word = [k for k , v in word_cnt.items() if v >= filer_num]\n",
    "    word_dict = {k:v for k, v in zip(word, range(3,len(word)+3))}\n",
    "    category=list(set(category))\n",
    "    category_dict={k:v for k, v in zip(category, range(1,len(category)+1))}\n",
    "\n",
    "    return news,news_index,category_dict,word_dict"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "source": [
    "%time news,news_index,category_dict,word_dict = read_news(docsfilename)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "100%|██████████| 113762/113762 [01:04<00:00, 1754.25it/s]\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "CPU times: user 1min 6s, sys: 4.8 s, total: 1min 11s\n",
      "Wall time: 1min 11s\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "source": [
    "word_dict['unk'] = 0\n",
    "word_dict['<sos>'] = 1\n",
    "word_dict['<eos>'] = 2"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "source": [
    "with open('../../data2/dict.pkl', 'wb') as f:\n",
    "    pickle.dump([news_index,category_dict,word_dict], f)\n",
    "with open('../../data2/news.pkl', 'wb') as f:\n",
    "    pickle.dump(news, f)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## get inputs for user encoder"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "source": [
    "def get_rep_for_userencoder(news,news_index,category_dict,word_dict):\n",
    "    news_num=len(news)+1\n",
    "    news_title=np.zeros((news_num,MAX_TITLE_LEN),dtype='int32')\n",
    "    news_body=np.zeros((news_num,MAX_BODY_LEN),dtype='int32')\n",
    "    news_vert=np.zeros((news_num),dtype='int32')\n",
    "    for key in news:    \n",
    "        vert,title,body=news[key]\n",
    "        doc_index=news_index[key]\n",
    "        news_vert[doc_index] = category_dict[vert]\n",
    "        counter = 0\n",
    "        for word_id in range(min(MAX_TITLE_LEN,len(title))):\n",
    "            if title[word_id] in word_dict:\n",
    "                news_title[doc_index,counter]=word_dict[title[word_id].lower()]\n",
    "                counter += 1\n",
    "        counter = 0\n",
    "        for word_id in range(min(MAX_BODY_LEN,len(body))):\n",
    "            if body[word_id] in word_dict:\n",
    "                news_body[doc_index,counter]=word_dict[body[word_id].lower()]\n",
    "                counter += 1\n",
    "    return news_vert, news_title, news_body"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "source": [
    "%time news_vert, news_title, news_body = get_rep_for_userencoder(news,news_index,category_dict,word_dict)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "CPU times: user 4.99 s, sys: 56 ms, total: 5.04 s\n",
      "Wall time: 5.04 s\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "source": [
    "len(news_vert),len(news_title), len(news_body)"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(113763, 113763, 113763)"
      ]
     },
     "metadata": {},
     "execution_count": 11
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "source": [
    "np.save('../../data2/news_vert.npy', news_vert)\n",
    "np.save('../../data2/news_title.npy', news_title)\n",
    "np.save('../../data2/news_body.npy', news_body)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## get inputs/ targets for seq2seq model"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "source": [
    "def get_rep_for_seq2seq(news,news_index,word_dict):\n",
    "    news_num=len(news)+1\n",
    "    sources=np.zeros((news_num,MAX_CONTENT_LEN),dtype='int32')\n",
    "    target_inputs=np.zeros((news_num,MAX_TITLE_LEN),dtype='int32')\n",
    "    target_outputs=np.zeros((news_num,MAX_TITLE_LEN),dtype='int32')\n",
    "    for key in tqdm(news):    \n",
    "        _, title, body = news[key]\n",
    "        doc_index=news_index[key]\n",
    "        counter = 0\n",
    "        for word_id in range(min(MAX_CONTENT_LEN-1,len(body))):\n",
    "            if body[word_id] in word_dict:\n",
    "                sources[doc_index,counter]=word_dict[body[word_id].lower()]\n",
    "                counter += 1\n",
    "        sources[doc_index,counter] = 2 \n",
    "        \n",
    "        target_inputs[doc_index,0] = 1\n",
    "        counter = 1\n",
    "        for word_id in range(min(MAX_TITLE_LEN-1,len(title))):\n",
    "            if title[word_id] in word_dict:\n",
    "                target_inputs[doc_index,counter]=word_dict[title[word_id].lower()]\n",
    "                counter += 1\n",
    "        \n",
    "        counter = 0\n",
    "        for word_id in range(min(MAX_TITLE_LEN-1,len(title))):\n",
    "            if title[word_id] in word_dict:\n",
    "                target_outputs[doc_index,counter]=word_dict[title[word_id].lower()]\n",
    "                counter += 1\n",
    "        target_outputs[doc_index,counter] = 2\n",
    "        \n",
    "    return sources, target_inputs, target_outputs"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "source": [
    "%time sources, target_inputs, target_outputs = get_rep_for_seq2seq(news,news_index,word_dict)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "100%|██████████| 113762/113762 [00:18<00:00, 6243.60it/s]"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "CPU times: user 17.9 s, sys: 440 ms, total: 18.3 s\n",
      "Wall time: 18.2 s\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "source": [
    "np.save('../../data2/sources.npy', sources)\n",
    "np.save('../../data2/target_inputs.npy', target_inputs)\n",
    "np.save('../../data2/target_outputs.npy', target_outputs)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## get embedding matrix"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "source": [
    "def load_matrix(embedding_path,word_dict):\n",
    "    mu, sigma = 0, 0.1\n",
    "    embedding_zero = np.zeros((1,300))\n",
    "    embedding_matrix = np.random.normal(mu, sigma, (len(word_dict)-1, WORD_EMBEDDING_DIM))\n",
    "    embedding_matrix = np.concatenate((embedding_zero,embedding_matrix))\n",
    "    have_word=[]\n",
    "    with open(os.path.join(embedding_path,'glove.840B.300d.txt'),'rb') as f:\n",
    "        while True:\n",
    "            l=f.readline()\n",
    "            if len(l)==0:\n",
    "                break\n",
    "            l=l.split()\n",
    "            word = l[0].decode()\n",
    "            if word in word_dict:\n",
    "                index = word_dict[word]\n",
    "                tp = [float(x) for x in l[1:]]\n",
    "                embedding_matrix[index]=np.array(tp)\n",
    "                have_word.append(word)\n",
    "    return embedding_matrix,have_word"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "source": [
    "%time embedding_matrix, have_word = load_matrix('../../data',word_dict)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "CPU times: user 40.8 s, sys: 2.49 s, total: 43.3 s\n",
      "Wall time: 39 s\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "source": [
    "len(word_dict),len(have_word)"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(141910, 100875)"
      ]
     },
     "metadata": {},
     "execution_count": 18
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "source": [
    "np.save('../../data2/embedding_matrix.npy', embedding_matrix)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## get train/ valid/ test examples from user logs"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "source": [
    "def Doc2ID(doclist,news2id):\n",
    "    return [news2id[i] for i in doclist if i in news2id ]"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "source": [
    "def PadDoc(doclist):\n",
    "    if len(doclist) >= MAX_CLICK_LEN:\n",
    "        return doclist[-MAX_CLICK_LEN:]\n",
    "    else:\n",
    "        return [0] * (MAX_CLICK_LEN-len(doclist)) + doclist[:MAX_CLICK_LEN]"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "source": [
    "def user2dict(users):\n",
    "    user_set = set(users)\n",
    "    user_dict = {k:v for k, v in zip(user_set, range(0,len(user_set)))}\n",
    "    return user_dict"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "source": [
    "def parse_train_user(filename,news_index):\n",
    "        \n",
    "    df = pd.read_csv(filename, sep='\\t')\n",
    "    df.fillna(value=\" \", inplace=True)\n",
    "    \n",
    "    df['ClicknewsID'] = df['ClicknewsID'].apply(lambda x: PadDoc(Doc2ID(x.split(),news_index)))\n",
    "    \n",
    "    df['pos']  = df['pos'].apply(lambda x: Doc2ID(x.split(),news_index))\n",
    "    df['neg'] = df['neg'].apply(lambda x: Doc2ID(x.split(),news_index))\n",
    "    \n",
    "    pos_neg_lists = []\n",
    "    for userindex, (pos_list, neg_list) in tqdm(enumerate(zip(df['pos'].values.tolist(), df['neg'].values.tolist()))):\n",
    "        if len(pos_list) and len(neg_list):\n",
    "            # sampling 1 negative sample for 1 pos sample\n",
    "            min_len = min(len(pos_list), len(neg_list))\n",
    "            np.random.shuffle(pos_list)\n",
    "            np.random.shuffle(neg_list)\n",
    "            for i in range(min_len):\n",
    "                pos_neg_lists.append([userindex, [pos_list[i],neg_list[i]],[1,0]])\n",
    "        \n",
    "    return df['ClicknewsID'].values.tolist(), pos_neg_lists"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "source": [
    "%time TrainUsers, TrainSamples = parse_train_user(trainfilename, news_index)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "400000it [00:07, 50885.87it/s]"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "CPU times: user 35.5 s, sys: 2.85 s, total: 38.3 s\n",
      "Wall time: 34.7 s\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "source": [
    "with open('../../data2/TrainUsers.pkl', 'wb') as f:\n",
    "    pickle.dump(TrainUsers, f)\n",
    "with open('../../data2/TrainSamples.pkl', 'wb') as f:\n",
    "    pickle.dump(TrainSamples, f)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "source": [
    "def parse_valid_user(filename,news_index):\n",
    "        \n",
    "    df = pd.read_csv(filename, sep='\\t')\n",
    "    df.fillna(value=\" \", inplace=True)\n",
    "    \n",
    "    df['ClicknewsID'] = df['ClicknewsID'].apply(lambda x: PadDoc(Doc2ID(x.split(),news_index)))\n",
    "    \n",
    "    df['pos']  = df['pos'].apply(lambda x: Doc2ID(x.split(),news_index))\n",
    "    df['neg'] = df['neg'].apply(lambda x: Doc2ID(x.split(),news_index))\n",
    "    \n",
    "    pos_neg_lists = []\n",
    "    for userindex, (pos_list, neg_list) in enumerate(zip(df['pos'].values.tolist(), df['neg'].values.tolist())):\n",
    "        if len(pos_list) and len(neg_list):\n",
    "            pos_neg_lists.append([userindex, pos_list+neg_list,[1]*len(pos_list)+[0]*len(neg_list)])\n",
    "        \n",
    "    return df['ClicknewsID'].values.tolist(), pos_neg_lists"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "source": [
    "%time ValidUsers, ValidSamples = parse_valid_user(validfilename,news_index)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "CPU times: user 10.6 s, sys: 916 ms, total: 11.5 s\n",
      "Wall time: 11.9 s\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "source": [
    "with open('../../data2/ValidUsers.pkl', 'wb') as f:\n",
    "    pickle.dump(ValidUsers, f)\n",
    "with open('../../data2/ValidSamples.pkl', 'wb') as f:\n",
    "    pickle.dump(ValidSamples, f)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "source": [
    "def parse_test_user(filename,news_index):\n",
    "        \n",
    "    df = pd.read_csv(filename, sep='\\t')\n",
    "    \n",
    "    df['clicknewsID'] = df['clicknewsID'].apply(lambda x: PadDoc(Doc2ID(x.split(','),news_index)))\n",
    "    \n",
    "    df['posnewID']  = df['posnewID'].apply(lambda x: Doc2ID(x.split(','),news_index))\n",
    "    \n",
    "    df['rewrite_titles'] = df['rewrite_titles'].apply(lambda x: [i.lower() for i in x.split(';;')] )\n",
    "    \n",
    "    pos_lists = []\n",
    "    for userindex, (pos_lis, rewrite_title_lis) in enumerate(zip(df['posnewID'].values.tolist(), df['rewrite_titles'].values.tolist())):\n",
    "        for pos, rewrite_title in zip(pos_lis, rewrite_title_lis):\n",
    "            if rewrite_title.strip() == '':\n",
    "                continue\n",
    "            else:\n",
    "                pos_lists.append([userindex, pos, rewrite_title])\n",
    "    \n",
    "    return df['clicknewsID'].values.tolist(), pos_lists"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "source": [
    "%time TestUsers, TestSamples = parse_test_user(testfilename,news_index)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "CPU times: user 60 ms, sys: 0 ns, total: 60 ms\n",
      "Wall time: 83.3 ms\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "source": [
    "with open('../../data2/TestUsers.pkl', 'wb') as f:\n",
    "    pickle.dump(TestUsers, f)\n",
    "with open('../../data2/TestSamples.pkl', 'wb') as f:\n",
    "    pickle.dump(TestSamples, f)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "source": [
    "TestSamples[0]"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "[0,\n",
       " 14111,\n",
       " \"legal battle looms over trump epa's rule change of obama's clean power plan rule\"]"
      ]
     },
     "metadata": {},
     "execution_count": 32
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "f1ea14d0bd9c7a0f4f2d255c0662c7e1119328c505d1c17d9d8f159415dfcf69"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}